{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Sklearn compatible Exponentiated Gradient Reduction\n",
    "\n",
    "Exponentiated gradient reduction is an in-processing technique that reduces fair classification to a sequence of cost-sensitive classification problems. It returns the randomized classifier with the lowest empirical error subject to the chosen\n",
    "fairness constraints. The implementation wraps the\n",
    "`fairlearn.reductions.ExponentiatedGradient` class from the [fairlearn](https://github.com/fairlearn/fairlearn) library,\n",
    "licensed under the MIT License, Copyright Microsoft Corporation."
   ]
  },
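  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As a rough sketch of what \"lowest empirical error subject to fairness constraints\" means, the reduction can be written as a constrained optimization over randomized classifiers $Q$. The symbols $Q$, $M$, $\\hat{\\mu}$, $\\hat{c}$ and $\\lambda$ are notation introduced here for illustration and do not appear in the aif360 API:\n",
    "\n",
    "$$ \\min_{Q} \\; \\widehat{\\mathrm{err}}(Q) \\quad \\text{subject to} \\quad M \\hat{\\mu}(Q) \\le \\hat{c} $$\n",
    "\n",
    "where $\\hat{\\mu}(Q)$ collects the empirical conditional moments that define the fairness constraint. Under this formulation, the algorithm looks for a saddle point of the Lagrangian\n",
    "\n",
    "$$ L(Q, \\lambda) = \\widehat{\\mathrm{err}}(Q) + \\lambda^{\\top} \\left( M \\hat{\\mu}(Q) - \\hat{c} \\right), \\qquad \\lambda \\ge 0 $$\n",
    "\n",
    "The multipliers $\\lambda$ are updated with exponentiated gradient steps, and each best response for $Q$ is obtained by solving a cost-sensitive classification problem with the base estimator."
   ]
  },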
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import warnings\n",
    "warnings.filterwarnings(\"ignore\", category=FutureWarning)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import pandas as pd\n",
    "\n",
    "from sklearn.compose import make_column_transformer\n",
    "from sklearn.linear_model import LogisticRegression\n",
    "from sklearn.metrics import accuracy_score\n",
    "from sklearn.model_selection import train_test_split\n",
    "from sklearn.preprocessing import OneHotEncoder\n",
    "\n",
    "from aif360.sklearn.inprocessing import ExponentiatedGradientReduction\n",
    "\n",
    "from aif360.sklearn.datasets import fetch_adult\n",
    "from aif360.sklearn.metrics import average_odds_error"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Loading data"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Datasets are formatted as separate `X` (# samples x # features) and `y` (# samples x # labels) DataFrames. The index of each DataFrame contains the protected attribute values for each sample. Datasets may also load a `sample_weight` object to be used with certain algorithms/metrics. This layout keeps aif360 compatible with scikit-learn objects.\n",
    "\n",
    "For example, we can easily load the Adult dataset from UCI with the following line:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th>age</th>\n",
       "      <th>workclass</th>\n",
       "      <th>education</th>\n",
       "      <th>education-num</th>\n",
       "      <th>marital-status</th>\n",
       "      <th>occupation</th>\n",
       "      <th>relationship</th>\n",
       "      <th>race</th>\n",
       "      <th>sex</th>\n",
       "      <th>capital-gain</th>\n",
       "      <th>capital-loss</th>\n",
       "      <th>hours-per-week</th>\n",
       "      <th>native-country</th>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>race</th>\n",
       "      <th>sex</th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>Non-white</th>\n",
       "      <th>Male</th>\n",
       "      <td>25.0</td>\n",
       "      <td>Private</td>\n",
       "      <td>11th</td>\n",
       "      <td>7.0</td>\n",
       "      <td>Never-married</td>\n",
       "      <td>Machine-op-inspct</td>\n",
       "      <td>Own-child</td>\n",
       "      <td>Black</td>\n",
       "      <td>Male</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>40.0</td>\n",
       "      <td>United-States</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th rowspan=\"2\" valign=\"top\">White</th>\n",
       "      <th>Male</th>\n",
       "      <td>38.0</td>\n",
       "      <td>Private</td>\n",
       "      <td>HS-grad</td>\n",
       "      <td>9.0</td>\n",
       "      <td>Married-civ-spouse</td>\n",
       "      <td>Farming-fishing</td>\n",
       "      <td>Husband</td>\n",
       "      <td>White</td>\n",
       "      <td>Male</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>50.0</td>\n",
       "      <td>United-States</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>Male</th>\n",
       "      <td>28.0</td>\n",
       "      <td>Local-gov</td>\n",
       "      <td>Assoc-acdm</td>\n",
       "      <td>12.0</td>\n",
       "      <td>Married-civ-spouse</td>\n",
       "      <td>Protective-serv</td>\n",
       "      <td>Husband</td>\n",
       "      <td>White</td>\n",
       "      <td>Male</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>40.0</td>\n",
       "      <td>United-States</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>Non-white</th>\n",
       "      <th>Male</th>\n",
       "      <td>44.0</td>\n",
       "      <td>Private</td>\n",
       "      <td>Some-college</td>\n",
       "      <td>10.0</td>\n",
       "      <td>Married-civ-spouse</td>\n",
       "      <td>Machine-op-inspct</td>\n",
       "      <td>Husband</td>\n",
       "      <td>Black</td>\n",
       "      <td>Male</td>\n",
       "      <td>7688.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>40.0</td>\n",
       "      <td>United-States</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>White</th>\n",
       "      <th>Male</th>\n",
       "      <td>34.0</td>\n",
       "      <td>Private</td>\n",
       "      <td>10th</td>\n",
       "      <td>6.0</td>\n",
       "      <td>Never-married</td>\n",
       "      <td>Other-service</td>\n",
       "      <td>Not-in-family</td>\n",
       "      <td>White</td>\n",
       "      <td>Male</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>30.0</td>\n",
       "      <td>United-States</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "                 age  workclass     education  education-num  \\\n",
       "race      sex                                                  \n",
       "Non-white Male  25.0    Private          11th            7.0   \n",
       "White     Male  38.0    Private       HS-grad            9.0   \n",
       "          Male  28.0  Local-gov    Assoc-acdm           12.0   \n",
       "Non-white Male  44.0    Private  Some-college           10.0   \n",
       "White     Male  34.0    Private          10th            6.0   \n",
       "\n",
       "                    marital-status         occupation   relationship   race  \\\n",
       "race      sex                                                                 \n",
       "Non-white Male       Never-married  Machine-op-inspct      Own-child  Black   \n",
       "White     Male  Married-civ-spouse    Farming-fishing        Husband  White   \n",
       "          Male  Married-civ-spouse    Protective-serv        Husband  White   \n",
       "Non-white Male  Married-civ-spouse  Machine-op-inspct        Husband  Black   \n",
       "White     Male       Never-married      Other-service  Not-in-family  White   \n",
       "\n",
       "                 sex  capital-gain  capital-loss  hours-per-week  \\\n",
       "race      sex                                                      \n",
       "Non-white Male  Male           0.0           0.0            40.0   \n",
       "White     Male  Male           0.0           0.0            50.0   \n",
       "          Male  Male           0.0           0.0            40.0   \n",
       "Non-white Male  Male        7688.0           0.0            40.0   \n",
       "White     Male  Male           0.0           0.0            30.0   \n",
       "\n",
       "               native-country  \n",
       "race      sex                  \n",
       "Non-white Male  United-States  \n",
       "White     Male  United-States  \n",
       "          Male  United-States  \n",
       "Non-white Male  United-States  \n",
       "White     Male  United-States  "
      ]
     },
     "execution_count": 3,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "X, y, sample_weight = fetch_adult()\n",
    "X.head()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To match the old version, we also remap the \"race\" feature to \"White\"/\"Non-white\", treating every value other than \"White\" as \"Non-white\","
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "X.race = X.race.cat.set_categories(['Non-white', 'White'], ordered=True).fillna('Non-white')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We can then map the protected attributes to integers,"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "X.index = pd.MultiIndex.from_arrays(X.index.codes, names=X.index.names)\n",
    "y.index = pd.MultiIndex.from_arrays(y.index.codes, names=y.index.names)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "and the target classes to 0/1,"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "y = pd.Series(y.factorize(sort=True)[0], index=y.index)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "and finally split the dataset into 70% training and 30% test data:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "(X_train, X_test,\n",
    " y_train, y_test) = train_test_split(X, y, train_size=0.7, random_state=1234567)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We one-hot encode the categorical features with scikit-learn. The encoded column names make it easy to identify the columns associated with protected attributes, which Exponentiated Gradient Reduction requires."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th>workclass_Federal-gov</th>\n",
       "      <th>workclass_Local-gov</th>\n",
       "      <th>workclass_Private</th>\n",
       "      <th>workclass_Self-emp-inc</th>\n",
       "      <th>workclass_Self-emp-not-inc</th>\n",
       "      <th>workclass_State-gov</th>\n",
       "      <th>workclass_Without-pay</th>\n",
       "      <th>education_10th</th>\n",
       "      <th>education_11th</th>\n",
       "      <th>education_12th</th>\n",
       "      <th>...</th>\n",
       "      <th>native-country_Thailand</th>\n",
       "      <th>native-country_Trinadad&amp;Tobago</th>\n",
       "      <th>native-country_United-States</th>\n",
       "      <th>native-country_Vietnam</th>\n",
       "      <th>native-country_Yugoslavia</th>\n",
       "      <th>age</th>\n",
       "      <th>education-num</th>\n",
       "      <th>capital-gain</th>\n",
       "      <th>capital-loss</th>\n",
       "      <th>hours-per-week</th>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>race</th>\n",
       "      <th>sex</th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th rowspan=\"5\" valign=\"top\">1</th>\n",
       "      <th>1</th>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>1.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>...</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>1.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>58.0</td>\n",
       "      <td>11.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>42.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>1.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>...</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>51.0</td>\n",
       "      <td>12.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>30.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>1.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>...</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>1.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>26.0</td>\n",
       "      <td>14.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>1887.0</td>\n",
       "      <td>40.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>1.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>...</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>44.0</td>\n",
       "      <td>3.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>40.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>1.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>1.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>...</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>1.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>33.0</td>\n",
       "      <td>6.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>40.0</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "<p>5 rows × 100 columns</p>\n",
       "</div>"
      ],
      "text/plain": [
       "          workclass_Federal-gov  workclass_Local-gov  workclass_Private  \\\n",
       "race sex                                                                  \n",
       "1    1                      0.0                  0.0                0.0   \n",
       "     0                      0.0                  0.0                0.0   \n",
       "     1                      0.0                  0.0                1.0   \n",
       "     1                      0.0                  0.0                1.0   \n",
       "     1                      0.0                  0.0                1.0   \n",
       "\n",
       "          workclass_Self-emp-inc  workclass_Self-emp-not-inc  \\\n",
       "race sex                                                       \n",
       "1    1                       0.0                         1.0   \n",
       "     0                       0.0                         1.0   \n",
       "     1                       0.0                         0.0   \n",
       "     1                       0.0                         0.0   \n",
       "     1                       0.0                         0.0   \n",
       "\n",
       "          workclass_State-gov  workclass_Without-pay  education_10th  \\\n",
       "race sex                                                               \n",
       "1    1                    0.0                    0.0             0.0   \n",
       "     0                    0.0                    0.0             0.0   \n",
       "     1                    0.0                    0.0             0.0   \n",
       "     1                    0.0                    0.0             0.0   \n",
       "     1                    0.0                    0.0             1.0   \n",
       "\n",
       "          education_11th  education_12th  ...  native-country_Thailand  \\\n",
       "race sex                                  ...                            \n",
       "1    1               0.0             0.0  ...                      0.0   \n",
       "     0               0.0             0.0  ...                      0.0   \n",
       "     1               0.0             0.0  ...                      0.0   \n",
       "     1               0.0             0.0  ...                      0.0   \n",
       "     1               0.0             0.0  ...                      0.0   \n",
       "\n",
       "          native-country_Trinadad&Tobago  native-country_United-States  \\\n",
       "race sex                                                                 \n",
       "1    1                               0.0                           1.0   \n",
       "     0                               0.0                           0.0   \n",
       "     1                               0.0                           1.0   \n",
       "     1                               0.0                           0.0   \n",
       "     1                               0.0                           1.0   \n",
       "\n",
       "          native-country_Vietnam  native-country_Yugoslavia   age  \\\n",
       "race sex                                                            \n",
       "1    1                       0.0                        0.0  58.0   \n",
       "     0                       0.0                        0.0  51.0   \n",
       "     1                       0.0                        0.0  26.0   \n",
       "     1                       0.0                        0.0  44.0   \n",
       "     1                       0.0                        0.0  33.0   \n",
       "\n",
       "          education-num  capital-gain  capital-loss  hours-per-week  \n",
       "race sex                                                             \n",
       "1    1             11.0           0.0           0.0            42.0  \n",
       "     0             12.0           0.0           0.0            30.0  \n",
       "     1             14.0           0.0        1887.0            40.0  \n",
       "     1              3.0           0.0           0.0            40.0  \n",
       "     1              6.0           0.0           0.0            40.0  \n",
       "\n",
       "[5 rows x 100 columns]"
      ]
     },
     "execution_count": 8,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "ohe = make_column_transformer(\n",
    "        (OneHotEncoder(sparse_output=False), X_train.dtypes == 'category'),\n",
    "        remainder='passthrough', verbose_feature_names_out=False)\n",
    "X_train = pd.DataFrame(ohe.fit_transform(X_train), columns=ohe.get_feature_names_out(), index=X_train.index)\n",
    "X_test = pd.DataFrame(ohe.transform(X_test), columns=ohe.get_feature_names_out(), index=X_test.index)\n",
    "\n",
    "X_train.head()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The protected attribute information is also carried in the index of the labels:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "race  sex\n",
       "1     1      0\n",
       "      0      1\n",
       "      1      1\n",
       "      1      0\n",
       "      1      0\n",
       "dtype: int64"
      ]
     },
     "execution_count": 9,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "y_train.head()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Running metrics"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "With the data in this format, we can easily train a scikit-learn model and get predictions for the test data:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "0.8460234392275374"
      ]
     },
     "execution_count": 10,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "y_pred = LogisticRegression(solver='liblinear').fit(X_train, y_train).predict(X_test)\n",
    "lr_acc = accuracy_score(y_test, y_pred)\n",
    "lr_acc"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We can assess how close the predictions are to equality of odds.\n",
    "\n",
    "`average_odds_error()` computes the (unweighted) average of the absolute values of the true positive rate (TPR) difference and false positive rate (FPR) difference, i.e.:\n",
    "\n",
    "$$ \\tfrac{1}{2}\\left(|FPR_{D = \\text{unprivileged}} - FPR_{D = \\text{privileged}}| + |TPR_{D = \\text{unprivileged}} - TPR_{D = \\text{privileged}}|\\right) $$"
   ]
  },
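  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "For reference, the per-group rates in the expression above follow the standard confusion-matrix definitions. The subscript $d$ and the counts $TP_d$, $FP_d$, $TN_d$, $FN_d$ are notation introduced here, denoting the counts restricted to samples with $D = d$:\n",
    "\n",
    "$$ TPR_{D = d} = \\frac{TP_d}{TP_d + FN_d}, \\qquad FPR_{D = d} = \\frac{FP_d}{FP_d + TN_d} $$\n",
    "\n",
    "A value of 0 for the average odds error means both groups have identical true positive and false positive rates. Because each absolute difference is at most 1, the metric lies in $[0, 1]$."
   ]
  },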
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "0.09335303807799161"
      ]
     },
     "execution_count": 11,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "lr_aoe_sex = average_odds_error(y_test, y_pred, prot_attr='sex')\n",
    "lr_aoe_sex"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "0.06751597777565721"
      ]
     },
     "execution_count": 12,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "lr_aoe_race = average_odds_error(y_test, y_pred, prot_attr='race')\n",
    "lr_aoe_race"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Exponentiated Gradient Reduction"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Choose a base model for the randomized classifier:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [],
   "source": [
    "estimator = LogisticRegression(solver='liblinear')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Determine the columns associated with the protected attribute(s):"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [],
   "source": [
    "prot_attr_cols = [colname for colname in X_train if \"sex\" in colname or \"race\" in colname]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Train the randomized classifier and observe test accuracy. Other options for `constraints` include \"DemographicParity\", \"TruePositiveRateParity\", \"FalsePositiveRateParity\", and \"ErrorRateParity\"."
   ]
  },
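  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As a reminder of what the \"EqualizedOdds\" constraint used below asks for, here is the usual textbook statement. The symbols $h$ (classifier), $A$ (protected attributes) and $Y$ (true label) are notation introduced here:\n",
    "\n",
    "$$ P\\left[h(X) = 1 \\mid A = a, Y = y\\right] = P\\left[h(X) = 1 \\mid Y = y\\right] \\quad \\text{for all } a \\text{ and } y \\in \\{0, 1\\} $$\n",
    "\n",
    "The case $y = 1$ alone corresponds to true positive rate parity and the case $y = 0$ alone to false positive rate parity, which is why those appear as separate, weaker options."
   ]
  },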
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0.834303825458834\n"
     ]
    }
   ],
   "source": [
    "np.random.seed(0)  # for reproducibility\n",
    "exp_grad_red = ExponentiatedGradientReduction(prot_attr=prot_attr_cols,\n",
    "                                              estimator=estimator,\n",
    "                                              constraints=\"EqualizedOdds\",\n",
    "                                              drop_prot_attr=False)\n",
    "exp_grad_red.fit(X_train, y_train)\n",
    "egr_acc = exp_grad_red.score(X_test, y_test)\n",
    "print(egr_acc)\n",
    "\n",
    "# Check that accuracy is comparable to the unconstrained model\n",
    "assert abs(lr_acc - egr_acc) <= 0.03"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0.02361168550972803\n"
     ]
    }
   ],
   "source": [
    "egr_aoe_sex = average_odds_error(y_test, exp_grad_red.predict(X_test), prot_attr='sex')\n",
    "print(egr_aoe_sex)\n",
    "\n",
    "# Check for improvement in average odds error for sex\n",
    "assert egr_aoe_sex < lr_aoe_sex"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0.024975550258025947\n"
     ]
    }
   ],
   "source": [
    "egr_aoe_race = average_odds_error(y_test, exp_grad_red.predict(X_test), prot_attr='race')\n",
    "print(egr_aoe_race)\n",
    "\n",
    "# Check for improvement in average odds error for race\n",
    "assert egr_aoe_race < lr_aoe_race"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Number of calls made to the base model algorithm:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "29"
      ]
     },
     "execution_count": 18,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "exp_grad_red.model_.n_oracle_calls_"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Maximum calls permitted"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "50"
      ]
     },
     "execution_count": 19,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "exp_grad_red.max_iter"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Instead of passing a string value for `constraints`, we can also pass a `fairlearn.reductions.Moment` object. You can use a predefined moment, as we do below, or create a custom moment using the fairlearn library."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "0.834303825458834"
      ]
     },
     "execution_count": 20,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "import fairlearn.reductions as red\n",
    "\n",
    "np.random.seed(0)  # needed for reproducibility\n",
    "exp_grad_red2 = ExponentiatedGradientReduction(prot_attr=prot_attr_cols,\n",
    "                                               estimator=estimator,\n",
    "                                               constraints=red.EqualizedOdds(),\n",
    "                                               drop_prot_attr=False)\n",
    "exp_grad_red2.fit(X_train, y_train)\n",
    "exp_grad_red2.score(X_test, y_test)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "0.02361168550972803"
      ]
     },
     "execution_count": 21,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "average_odds_error(y_test, exp_grad_red2.predict(X_test), prot_attr='sex')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "0.024975550258025947"
      ]
     },
     "execution_count": 22,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "average_odds_error(y_test, exp_grad_red2.predict(X_test), prot_attr='race')"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3.9.7 ('aif360')",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.7"
  },
  "vscode": {
   "interpreter": {
    "hash": "d0c5ced7753e77a483fec8ff7063075635521cce6e0bd54998c8f174742209dd"
   }
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
